import itertools
import logging

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import scipy.interpolate
import scipy.stats
import sklearn.decomposition
from mpl_toolkits.mplot3d import Axes3D

class MobilityPlot(object):
    """
    Plotting the Mobility of animals in each Period Type

    The plot methods expect a DataFrame with the columns ``TrialCnt``,
    ``PeriodType``, ``Time`` and ``Velocity`` (plus ``ExtendedTrialType`` for
    :meth:`plotAllTrials`). Each period type listed in :attr:`periodTypes` is
    drawn on its own unit-wide x-range: period ``i`` spans ``[i, i + 1]``.
    """
    colors = ['r', 'b', 'y', 'g', 'c', 'm', 'k']

    # Period types in plotting order; the position in this list is the
    # x-offset of the period on the normalized time axis.
    periodTypes = ['Precue', 'Cue', 'Reward', 'ITI']

    def __init__(self):
        # The plotter is stateless: all inputs are passed to the plot methods.
        pass

    @staticmethod
    def _zScore(data):
        """
        Standardize every column of ``data`` (z-score) and filter NaN.

        NaN values produced by the standardization (e.g. for a constant column)
        are replaced via ``np.nan_to_num``: NaN becomes 0 and +/-inf becomes
        a large finite number. A warning is logged when this happens.
        """
        zdata = data.apply(scipy.stats.mstats.zscore)
        if zdata.isnull().any().any():
            logging.warning('Filtering NaN in standardized data!')
            zdata = zdata.apply(np.nan_to_num)

        return zdata

    @staticmethod
    def _getColors(ids):
        """
        Return an iterator of ``(idx, ID, color)`` tuples.

        If ``ids`` is iterable, one tuple is produced per element of ``ids``.
        Otherwise ``ids`` is treated as a count N and the IDs ``0 .. N-1`` are
        used. Colors cycle through :attr:`MobilityPlot.colors`.
        """
        try:
            iter(ids)
        except TypeError:
            ids = range(int(ids))
        return zip(itertools.count(), ids, itertools.cycle(MobilityPlot.colors))

    @staticmethod
    def _normalizeTrial(trial, X, Y):
        """
        Normalize the time axis of one trial and append it to ``X`` and ``Y``.

        For the period type at position ``idx`` of :attr:`periodTypes`, the
        ``Time`` values of that period are shifted to start at 0. They are
        divided by their span and offset by ``idx``, so the period covers
        ``[idx, idx + 1]``. A period with zero duration (e.g. a single sample)
        stays at ``idx`` instead of becoming NaN. Periods without samples
        contribute empty lists. ``X`` and ``Y`` therefore each receive one
        list holding ``len(periodTypes)`` per-period lists.

        Times are assumed to be sorted within each period (first row = start,
        last row = end) -- TODO confirm with the data source.

        :param trial: DataFrame with ``PeriodType``, ``Time`` and ``Velocity``;
            it is not modified
        :param X: list collecting the normalized x-values per trial (appended to)
        :param Y: list collecting the velocities per trial (appended to)
        """
        tX = []
        tY = []
        for idx, periodType in enumerate(MobilityPlot.periodTypes):
            period = trial[trial['PeriodType'] == periodType]
            if period.empty:
                tX.append([])
                tY.append([])
                continue
            # Compute on a local Series: assigning to ``period.Time`` would
            # write into a filtered slice (SettingWithCopyWarning).
            time = period['Time'] - period['Time'].iloc[0]
            span = time.iloc[-1]
            if span != 0:
                time = time / span
            tX.append((time + float(idx)).tolist())
            tY.append(period['Velocity'].tolist())

        X.append(tX)
        Y.append(tY)

    @staticmethod
    def _importData(data):
        """
        Split ``data`` into trials (by ``TrialCnt``) and normalize each one.

        :param data: DataFrame with ``TrialCnt``, ``PeriodType``, ``Time`` and
            ``Velocity`` columns
        :return: ``(X, Y)``; one entry per trial, ordered by ``TrialCnt``,
            each a list of per-period value lists (see ``_normalizeTrial``)
        :raises AssertionError: if ``TrialCnt`` or ``PeriodType`` is missing
        """
        assert 'TrialCnt' in data.columns, 'Missing TrialCnt column!'
        assert 'PeriodType' in data.columns, 'Missing PeriodType column!'

        X = []
        Y = []
        try:
            #Cannot use apply here because the function has side effects
            for _, group in data.groupby('TrialCnt'):
                MobilityPlot._normalizeTrial(group, X, Y)
        except Exception as e:
            # Re-raise: returning a placeholder would make the caller's
            # ``X, Y = ...`` unpacking fail far away from the real cause.
            logging.error('Importing data failed! %s', e)
            raise
        return X, Y

    @staticmethod
    def _trialInterpolator(tX, tY):
        """
        Build a linear interpolator of velocity over normalized time for one trial.

        When the trial has no sample at x=0 (e.g. an empty Precue period), a
        ``(0.0, 0.0)`` anchor is prepended so the curve covers the start of
        the axis. The anchor is only added when needed: a duplicate x=0 makes
        ``interp1d`` compute 0/0 and return NaN there.

        :param tX: per-period x-value lists of one trial (from ``_normalizeTrial``)
        :param tY: matching per-period velocity lists
        :return: an ``interp1d`` returning NaN beyond the trial's sampled
            range, or ``None`` when fewer than two points are available
        """
        xs = np.concatenate([np.asarray(x, dtype=float) for x in tX] or [np.empty(0)])
        ys = np.concatenate([np.asarray(y, dtype=float) for y in tY] or [np.empty(0)])
        if xs.size == 0 or xs.min() > 0.0:
            xs = np.concatenate(([0.0], xs))
            ys = np.concatenate(([0.0], ys))
        if xs.size < 2:
            return None
        return scipy.interpolate.interp1d(xs, ys, bounds_error=False, fill_value=np.nan)

    def plotAllTrials(self, data, title):
        """
        Plot one mobility subplot per ``ExtendedTrialType`` in a shared figure.

        :param data: DataFrame accepted by :meth:`plot`, plus an
            ``ExtendedTrialType`` column
        :param title: figure title
        """
        assert 'ExtendedTrialType' in data.columns, 'Missing ExtendedTrialType column!'

        trialTypes = data['ExtendedTrialType'].unique()
        if len(trialTypes) == 0:
            logging.warning('No trial types to plot for %s', title)
            return

        # squeeze=False keeps AX two-dimensional even for a single trial type;
        # otherwise plt.subplots returns a bare Axes that cannot be indexed.
        f, AX = plt.subplots(len(trialTypes), 1, sharex=True, sharey=True, squeeze=False)
        for i, tt in enumerate(trialTypes):
            logging.debug('Processing %s', tt)
            self.plot(data[data['ExtendedTrialType'] == tt], title=tt, ax=AX[i, 0])
        f.suptitle(title, fontsize=15)
        f.show()

    def plot(self, data, title, ylim=(0, 50), ax=None):
        """
        Generate a line plot showing the normalized mobility of an animal for
        each period type.

        Every trial is drawn as faint lines, colored per period. The mean over
        trials is drawn in black, with a band of +/- one standard deviation.
        Trials that end before the longest one are ignored where they have no
        data.

        :param data: DataFrame with ``TrialCnt``, ``PeriodType``, ``Time`` and
            ``Velocity`` columns
        :param title: axes title
        :param ylim: ``(bottom, top)`` y-axis limits; ``None`` keeps the
            automatic limits
        :param ax: axes to draw into; if omitted, a new figure is created and shown
        """
        X, Y = self._importData(data)
        interpol = [self._trialInterpolator(tX, tY) for tX, tY in zip(X, Y)]
        interpol = [intF for intF in interpol if intF is not None]
        if not interpol:
            logging.warning('No data to plot for %s', title)
            return

        if ax is None:
            f, ax = plt.subplots()
        else:
            f = None

        colors = list(self._getColors(len(X[0])))
        for tX, tY in zip(X, Y):
            for x, y, c in zip(tX, tY, colors):
                ax.plot(x, y, color=c[2], alpha=0.3)

        N = 400
        # Span the longest trial, not just the first one; shorter trials
        # yield NaN past their end, and the nan-aware statistics skip those.
        xMax = max(x for tX in X for period in tX for x in period)
        intX = np.linspace(0.0, xMax, N)
        intY = np.vstack([intF(intX) for intF in interpol])

        mean = np.nanmean(intY, axis=0)
        std = np.nanstd(intY, axis=0)

        ax.plot(intX, mean, color='black', alpha=1.0, label='Mean')
        ax.fill_between(intX, mean-std, mean+std, alpha=0.2, color='black', label='+/- 1 SD')

        # Set the tick positions first, then one label per period centered in it.
        nPeriods = len(self.periodTypes)
        ax.set_xticks([idx + 0.5 for idx in range(nPeriods)])
        ax.set_xticklabels(self.periodTypes)
        ax.set_xlim(0, nPeriods)
        if ylim is not None:
            ax.set_ylim(*ylim)

        ax.set_ylabel('Mobility')

        ax.legend()

        ax.set_title(title)
        if f:
            f.show()

